from typing import Dict, Any, List, Optional, Union, Tuple
from hyperopt import fmin, tpe, rand, anneal, atpe, mix, Trials, STATUS_OK
from hypersense.optimizer.base_optimizer import BaseOptimizer
from functools import partial
from hyperopt.pyll.base import Apply, Literal
import time


class HyperoptOptimizer(BaseOptimizer):
    """Hyperparameter optimizer backed by hyperopt's ``fmin``.

    ``algo`` selects the hyperopt suggest function ("tpe", "rand", "anneal",
    "atpe", or "mix" together with ``mix_components``). Remaining keyword
    arguments are forwarded to ``BaseOptimizer``. This class reads ``space``,
    ``objective_fn``, ``max_trials``, ``mode``, ``points_to_evaluate`` and
    ``evaluated_rewards`` from ``self``; presumably ``BaseOptimizer`` sets them
    (TODO confirm).
    """

    def __init__(
        self,
        algo: Optional[str] = "tpe",
        mix_components: Optional[List[Tuple[float, str]]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.trials = Trials()
        self.algo = self._resolve_algo(algo, mix_components)

        # warm-start (if points+rewards given)
        self._warm_start()

    def _warm_start(self) -> None:
        """Seed ``self.trials`` with already-evaluated points as finished trials.

        Does nothing unless both ``points_to_evaluate`` and
        ``evaluated_rewards`` are non-empty.

        Raises:
            ValueError: if the two lists have different lengths.
        """
        points = self.points_to_evaluate
        rewards = self.evaluated_rewards
        if not (points and rewards):
            return
        if len(points) != len(rewards):
            raise ValueError(
                f"points_to_evaluate ({len(points)}) and evaluated_rewards "
                f"({len(rewards)}) must have the same length"
            )

        for cfg, reward in zip(points, rewards):
            # new_trial_ids reserves a fresh id immediately. len(self.trials.trials)
            # only changes after refresh(), so it would hand out duplicate tids.
            (tid,) = self.trials.new_trial_ids(1)
            # hyperopt's Trials.insert_trial_doc validates that all of these keys
            # are present. TPE reads misc["idxs"]/misc["vals"] to learn from the point.
            # NOTE(review): values for hp.choice parameters must be option
            # *indices*, as hyperopt stores them internally.
            self.trials.insert_trial_doc(
                {
                    "state": 2,  # JOB_STATE_DONE
                    "tid": tid,
                    "spec": None,
                    "result": {"loss": self._convert_score(reward), "status": STATUS_OK},
                    "misc": {
                        "tid": tid,
                        "cmd": ("domain_attachment", "FMinIter_Domain"),
                        "workdir": None,
                        "idxs": {k: [tid] for k in cfg},
                        "vals": {k: [v] for k, v in cfg.items()},
                    },
                    "exp_key": None,  # self.trials was created with the default exp_key
                    "owner": None,
                    "version": 0,
                    "book_time": None,
                    "refresh_time": None,
                }
            )
        # Inserted docs only become visible through .trials/.results after refresh.
        self.trials.refresh()

    def _resolve_algo(self, name: Optional[str], mix_components: Optional[List[Tuple[float, str]]] = None):
        """Map an algorithm name to a hyperopt ``suggest`` callable.

        Args:
            name: case-insensitive algorithm name. ``None`` falls back to "tpe".
            mix_components: ``[(weight, algo_name), ...]``, required when
                ``name == "mix"``. Weights must be non-negative and sum to 1.

        Returns:
            A suggest function suitable for ``fmin(algo=...)``.

        Raises:
            ValueError: unknown algorithm, unknown mix component, or invalid
                mix weights. These are checked here rather than inside
                ``fmin``, where ``optimize`` would swallow the error.
        """
        name = "tpe" if name is None else name.lower()

        algo_map = {
            "tpe": tpe.suggest,
            "rand": rand.suggest,
            "anneal": anneal.suggest,
            "atpe": atpe.suggest,
        }

        if name == "mix":
            if not mix_components:
                raise ValueError("When using algo='mix', you must provide mix_components=[(weight, algo_name)]")

            parsed_mix = []
            for weight, algo_name in mix_components:
                algo_name = algo_name.lower()
                if algo_name not in algo_map:
                    raise ValueError(f"Unknown mix component: {algo_name}")
                if weight < 0:
                    raise ValueError(f"Mix component weight must be non-negative, got {weight} for {algo_name}")
                parsed_mix.append((weight, algo_map[algo_name]))

            # Same tolerance as numpy.isclose(sum, 1.0) defaults, which
            # hyperopt's mix.suggest enforces at suggestion time.
            total_weight = sum(weight for weight, _ in parsed_mix)
            if abs(total_weight - 1.0) > 1e-8 + 1e-5:
                raise ValueError(f"mix_components weights must sum to 1, got {total_weight}")

            return partial(mix.suggest, p_suggest=parsed_mix)

        if name in algo_map:
            return algo_map[name]

        raise ValueError(f"Unsupported hyperopt algo: {name}")

    def _primary_score(self, score: Union[float, List[float]]) -> float:
        """Return the first element of a list/tuple score, or the score itself."""
        if isinstance(score, (list, tuple)):
            return score[0]
        return score

    def _convert_score(self, score: Union[float, List[float]]) -> float:
        """Handle mode=max by converting score to a loss for hyperopt (which minimizes).

        For list/tuple scores only the first element is used.
        """
        score = self._primary_score(score)
        return -score if self.mode == "max" else score

    def _objective_wrapper(self, config: Dict[str, Any]):
        """Evaluate ``config`` and return the ``{"loss", "status"}`` dict hyperopt requires.

        The raw (unconverted) score and the wall-clock duration are appended
        to ``self.trial_history`` as ``(config, score, elapsed)``.
        """
        start_trial = time.time()
        score = self.objective_fn(config)
        elapsed = time.time() - start_trial
        self.trial_history.append((config, score, elapsed))

        trial_id = len(self.trial_history) - 1
        # Format only the scalar primary value. ":.5f" on a list raises TypeError,
        # which would abort fmin.
        print(
            f"[Hyperopt] Trial {trial_id} finished with value: "
            f"{self._primary_score(score):.5f} and parameters: {config}"
        )

        return {"loss": self._convert_score(score), "status": STATUS_OK}

    def optimize(self) -> List[Tuple[Dict[str, Any], Any, float]]:
        """Run ``fmin`` and return this run's ``(config, score, elapsed)`` history.

        Failures inside ``fmin`` are reported and swallowed (best-effort). The
        history collected so far is still returned.
        """
        self.trial_history = []
        self._start_time = time.time()
        try:
            # NOTE(review): hyperopt's max_evals counts *all* docs in
            # self.trials, including warm-start points and earlier runs, so
            # fewer than max_trials new evaluations may happen. Confirm this
            # matches BaseOptimizer's intended meaning of max_trials.
            fmin(
                fn=self._objective_wrapper,
                space=self.space,
                algo=self.algo,
                max_evals=self.max_trials,
                trials=self.trials,
                rstate=None,
            )
        except Exception as e:
            print(f"[HyperoptOptimizer] Optimization failed: {e}")
        self.elapsed_time = time.time() - self._start_time
        return self.trial_history

    def _map_choice_indices_back(self, config: Dict[str, Any], space: Dict[str, Any]) -> Dict[str, Any]:
        """Replace hp.choice indices in ``config`` with the chosen literal values.

        Only top-level entries of a dict ``space`` whose node is a ``switch``
        (what hp.choice builds) and whose selected option is a pyll ``Literal``
        are mapped. Anything else is passed through unchanged: non-dict spaces,
        non-literal options, and values that are not a valid option index.
        """
        if not isinstance(space, dict):
            return dict(config)

        mapped = {}
        for k, v in config.items():
            space_node = space.get(k)
            mapped[k] = v

            # Only map if it is a hp.choice node
            if isinstance(space_node, Apply) and space_node.name == "switch":
                choices = space_node.pos_args[1:]  # skip the first arg (index)
                try:
                    option = choices[v]
                except (IndexError, TypeError):
                    # e.g. a warm-start value given as the option itself, not its index
                    continue
                if isinstance(option, Literal):
                    mapped[k] = option.obj

        return mapped

    def get_best_config(self, include_score: bool = False) -> Dict[str, Any]:
        """Return the best trial's parameters, optionally with its score and run time.

        The score is converted back from hyperopt's loss (sign flipped when
        ``mode == "max"``). ``elapsed_time`` is ``None`` when ``optimize`` has
        not recorded one.
        """
        best = self.trials.best_trial
        raw_config = {k: v[0] for k, v in best["misc"]["vals"].items()}
        mapped_config = self._map_choice_indices_back(raw_config, self.space)

        if include_score:
            score = best["result"]["loss"]
            score = -score if self.mode == "max" else score
            elapsed_time = getattr(self, "elapsed_time", None)
            return {
                "params": mapped_config,
                "score": score,
                # "is not None" so that a genuine 0.0 duration is not reported as None.
                "elapsed_time": (round(elapsed_time, 4) if elapsed_time is not None else None),
            }

        return mapped_config
